-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introducing pyro.infer.predictive.WeighedPredictive which reports weights along with predicted samples #3345
Conversation
Hi @BenZickel I think it's a great idea to add a Predictive interface that can incorporate both a set of samples and some sort importance weights. And I appreciate your efforts to share interface and code with
cc @martinjankowiak who may have a better understanding of the relationships between these inference algorithms. |
Thank you for your comments @fritzo. See below my feedback:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code looks great. Thanks for adding tests!
@@ -31,53 +34,58 @@ def _guess_max_plate_nesting(model, args, kwargs): | |||
return max_plate_nesting | |||
|
|||
|
|||
class _predictiveResults(NamedTuple): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the cleanup! This will also help us with #2550
The Problem
When sampling from the posterior predictive distribution we are often using a guide as an approximation for the posterior. As mentioned in #3340 it is often desirable to correct for the non-uniform per sample gap between the model log-probability and the guide log-probability. This gap is essentially the weight that should be assigned to each sample.
The current implementation of
pyro.infer.predictive.Predictive
does not support calculation of these weights.The Proposed Solution
Add
pyro.infer.predictive.WeighedPredictive
which supports calculation of per sample weights.The implementation relies on three objects:
pyro.infer.predictive.Predictive
).pyro.infer.predictive.Predictive
).pyro.infer.predictive.WeighedPredictive
when called as the keyword argumentmodel_guide
(as in the model that was used when creating the guide).The
model_guide
is what enables calculation of the weights. If not provided we use the model provided at instantiation ofpyro.infer.predictive.WeighedPredictive
as themodel_guide
(in this case the model provided at instantiation is usually already with observations constrained to be the actual observations).Design Considerations
pyro.infer.predictive.Predictive
.pyro.infer.predictive.Predictive
when implementingpyro.infer.predictive.WeighedPredictive
.